{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/utils/optimize.py:187: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/models/research/neural_stack.py:52: The name tf.nn.rnn_cell.RNNCell is deprecated. Please use tf.compat.v1.nn.rnn_cell.RNNCell instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/utils/trainer_lib.py:111: The name tf.OptimizerOptions is deprecated. Please use tf.compat.v1.OptimizerOptions instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/utils/trainer_lib.py:111: The name tf.OptimizerOptions is deprecated. Please use tf.compat.v1.OptimizerOptions instead.\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_gan/python/estimator/tpu_gan_estimator.py:42: The name tf.estimator.tpu.TPUEstimator is deprecated. Please use tf.compat.v1.estimator.tpu.TPUEstimator instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_gan/python/estimator/tpu_gan_estimator.py:42: The name tf.estimator.tpu.TPUEstimator is deprecated. Please use tf.compat.v1.estimator.tpu.TPUEstimator instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from tensor2tensor import models\n",
    "from tensor2tensor import problems\n",
    "from tensor2tensor.layers import common_layers\n",
    "from tensor2tensor.utils import trainer_lib\n",
    "from tensor2tensor.utils import t2t_model\n",
    "from tensor2tensor.utils import registry\n",
    "from tensor2tensor.utils import metrics\n",
    "from tensor2tensor.data_generators import problem\n",
    "from tensor2tensor.data_generators import text_problems\n",
    "from tensor2tensor.data_generators import translate\n",
    "from tensor2tensor.utils import registry"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "@registry.register_problem\n",
    "class TRANSLATION32k(translate.TranslateProblem):\n",
    "\n",
    "    @property\n",
    "    def additional_training_datasets(self):\n",
    "        \"\"\"Allow subclasses to add training datasets.\"\"\"\n",
    "        return []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "PROBLEM = 'translatio_n32k'\n",
    "problem = problems.problem(PROBLEM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('t2t/data/vocab.translatio_n32k.32768.subwords',\n",
       " 't2t/train-base/model.ckpt-75000')"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "import os\n",
    "\n",
    "vocab_file = \"t2t/data/vocab.translatio_n32k.32768.subwords\"\n",
    "ckpt_path = tf.train.latest_checkpoint(os.path.join('t2t/train-base'))\n",
    "vocab_file, ckpt_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from t import text_encoder\n",
    "\n",
    "encoder = text_encoder.SubwordTextEncoder(vocab_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensor2tensor.layers import modalities\n",
    "from tensor2tensor.layers import common_layers\n",
    "\n",
    "def top_p_logits(logits, p):\n",
    "    with tf.variable_scope('top_p_logits'):\n",
    "        logits_sort = tf.sort(logits, direction = 'DESCENDING')\n",
    "        probs_sort = tf.nn.softmax(logits_sort)\n",
    "        probs_sums = tf.cumsum(probs_sort, axis = 1, exclusive = True)\n",
    "        logits_masked = tf.where(\n",
    "            probs_sums < p, logits_sort, tf.ones_like(logits_sort) * 1000\n",
    "        )  # [batchsize, vocab]\n",
    "        min_logits = tf.reduce_min(\n",
    "            logits_masked, axis = 1, keepdims = True\n",
    "        )  # [batchsize, 1]\n",
    "        return tf.where(\n",
    "            logits < min_logits,\n",
    "            tf.ones_like(logits, dtype = logits.dtype) * -1e10,\n",
    "            logits,\n",
    "        )\n",
    "\n",
    "\n",
    "def sample(translate_model, features):\n",
    "    logits, losses = translate_model(features)\n",
    "    logits_shape = common_layers.shape_list(logits)\n",
    "    logits_p = logits[:,0,:,0,:] / translate_model.hparams.sampling_temp\n",
    "    logits_p = top_p_logits(logits_p, translate_model.hparams.top_p)\n",
    "    reshaped_logits = tf.reshape(logits_p, [-1, logits_shape[-1]])\n",
    "    choices = tf.multinomial(reshaped_logits, 1)\n",
    "    samples = tf.reshape(choices, logits_shape[:-1])\n",
    "    return samples, logits, losses\n",
    "\n",
    "def nucleus_sampling(translate_model, features, decode_length):\n",
    "    \"\"\"A slow greedy inference method.\n",
    "    Quadratic time in decode_length.\n",
    "    Args:\n",
    "      features: an map of string to `Tensor`\n",
    "      decode_length: an integer.  How many additional timesteps to decode.\n",
    "    Returns:\n",
    "      A dict of decoding results {\n",
    "          \"outputs\": integer `Tensor` of decoded ids of shape\n",
    "              [batch_size, <= decode_length] if beam_size == 1 or\n",
    "              [batch_size, top_beams, <= decode_length]\n",
    "          \"scores\": None\n",
    "          \"logits\": `Tensor` of shape [batch_size, time, 1, 1, vocab_size].\n",
    "          \"losses\": a dictionary: {loss-name (string): floating point `Scalar`}\n",
    "      }\n",
    "    \"\"\"\n",
    "    if not features:\n",
    "        features = {}\n",
    "    inputs_old = None\n",
    "    if 'inputs' in features and len(features['inputs'].shape) < 4:\n",
    "        inputs_old = features['inputs']\n",
    "        features['inputs'] = tf.expand_dims(features['inputs'], 2)\n",
    "    # Save the targets in a var and reassign it after the tf.while loop to avoid\n",
    "    # having targets being in a 'while' frame. This ensures targets when used\n",
    "    # in metric functions stays in the same frame as other vars.\n",
    "    targets_old = features.get('targets', None)\n",
    "\n",
    "    target_modality = translate_model._problem_hparams.modality['targets']\n",
    "\n",
    "    def infer_step(recent_output, recent_logits, unused_loss):\n",
    "        \"\"\"Inference step.\"\"\"\n",
    "        if not tf.executing_eagerly():\n",
    "            if translate_model._target_modality_is_real:\n",
    "                dim = translate_model._problem_hparams.vocab_size['targets']\n",
    "                if dim is not None and hasattr(\n",
    "                    translate_model._hparams, 'vocab_divisor'\n",
    "                ):\n",
    "                    dim += (-dim) % translate_model._hparams.vocab_divisor\n",
    "                recent_output.set_shape([None, None, None, dim])\n",
    "            else:\n",
    "                recent_output.set_shape([None, None, None, 1])\n",
    "        padded = tf.pad(recent_output, [[0, 0], [0, 1], [0, 0], [0, 0]])\n",
    "        features['targets'] = padded\n",
    "        # This is inefficient in that it generates samples at all timesteps,\n",
    "        # not just the last one, except if target_modality is pointwise.\n",
    "        samples, logits, losses = sample(translate_model, features)\n",
    "        # Concatenate the already-generated recent_output with last timestep\n",
    "        # of the newly-generated samples.\n",
    "        top = translate_model._hparams.top.get(\n",
    "            'targets', modalities.get_top(target_modality)\n",
    "        )\n",
    "        if getattr(top, 'pointwise', False):\n",
    "            cur_sample = samples[:, -1, :, :]\n",
    "        else:\n",
    "            cur_sample = samples[\n",
    "                :, common_layers.shape_list(recent_output)[1], :, :\n",
    "            ]\n",
    "        if translate_model._target_modality_is_real:\n",
    "            cur_sample = tf.expand_dims(cur_sample, axis = 1)\n",
    "            samples = tf.concat([recent_output, cur_sample], axis = 1)\n",
    "        else:\n",
    "            cur_sample = tf.to_int64(tf.expand_dims(cur_sample, axis = 1))\n",
    "            samples = tf.concat([recent_output, cur_sample], axis = 1)\n",
    "            if not tf.executing_eagerly():\n",
    "                samples.set_shape([None, None, None, 1])\n",
    "\n",
    "        # Assuming we have one shard for logits.\n",
    "        logits = tf.concat([recent_logits, logits[:, -1:]], 1)\n",
    "        loss = sum([l for l in losses.values() if l is not None])\n",
    "        return samples, logits, loss\n",
    "\n",
    "    # Create an initial output tensor. This will be passed\n",
    "    # to the infer_step, which adds one timestep at every iteration.\n",
    "    if 'partial_targets' in features:\n",
    "        initial_output = tf.to_int64(features['partial_targets'])\n",
    "        while len(initial_output.get_shape().as_list()) < 4:\n",
    "            initial_output = tf.expand_dims(initial_output, 2)\n",
    "        batch_size = common_layers.shape_list(initial_output)[0]\n",
    "    else:\n",
    "        batch_size = common_layers.shape_list(features['inputs'])[0]\n",
    "        if translate_model._target_modality_is_real:\n",
    "            dim = translate_model._problem_hparams.vocab_size['targets']\n",
    "            if dim is not None and hasattr(\n",
    "                translate_model._hparams, 'vocab_divisor'\n",
    "            ):\n",
    "                dim += (-dim) % translate_model._hparams.vocab_divisor\n",
    "            initial_output = tf.zeros(\n",
    "                (batch_size, 0, 1, dim), dtype = tf.float32\n",
    "            )\n",
    "        else:\n",
    "            initial_output = tf.zeros((batch_size, 0, 1, 1), dtype = tf.int64)\n",
    "    # Hack: foldl complains when the output shape is less specified than the\n",
    "    # input shape, so we confuse it about the input shape.\n",
    "    initial_output = tf.slice(\n",
    "        initial_output, [0, 0, 0, 0], common_layers.shape_list(initial_output)\n",
    "    )\n",
    "    target_modality = translate_model._problem_hparams.modality['targets']\n",
    "    if (\n",
    "        target_modality == modalities.ModalityType.CLASS_LABEL\n",
    "        or translate_model._problem_hparams.get('regression_targets')\n",
    "    ):\n",
    "        decode_length = 1\n",
    "    else:\n",
    "        if 'partial_targets' in features:\n",
    "            prefix_length = common_layers.shape_list(\n",
    "                features['partial_targets']\n",
    "            )[1]\n",
    "        else:\n",
    "            prefix_length = common_layers.shape_list(features['inputs'])[1]\n",
    "        decode_length = prefix_length + decode_length\n",
    "\n",
    "    # Initial values of result, logits and loss.\n",
    "    result = initial_output\n",
    "    vocab_size = translate_model._problem_hparams.vocab_size['targets']\n",
    "    if vocab_size is not None and hasattr(\n",
    "        translate_model._hparams, 'vocab_divisor'\n",
    "    ):\n",
    "        vocab_size += (-vocab_size) % translate_model._hparams.vocab_divisor\n",
    "    if translate_model._target_modality_is_real:\n",
    "        logits = tf.zeros((batch_size, 0, 1, vocab_size))\n",
    "        logits_shape_inv = [None, None, None, None]\n",
    "    else:\n",
    "        # tensor of shape [batch_size, time, 1, 1, vocab_size]\n",
    "        logits = tf.zeros((batch_size, 0, 1, 1, vocab_size))\n",
    "        logits_shape_inv = [None, None, None, None, None]\n",
    "    if not tf.executing_eagerly():\n",
    "        logits.set_shape(logits_shape_inv)\n",
    "\n",
    "    loss = 0.0\n",
    "\n",
    "    def while_exit_cond(\n",
    "        result, logits, loss\n",
    "    ):  # pylint: disable=unused-argument\n",
    "        \"\"\"Exit the loop either if reach decode_length or EOS.\"\"\"\n",
    "        length = common_layers.shape_list(result)[1]\n",
    "\n",
    "        not_overflow = length < decode_length\n",
    "\n",
    "        if translate_model._problem_hparams.stop_at_eos:\n",
    "\n",
    "            def fn_not_eos():\n",
    "                return tf.not_equal(  # Check if the last predicted element is a EOS\n",
    "                    tf.squeeze(result[:, -1, :, :]), text_encoder.EOS_ID\n",
    "                )\n",
    "\n",
    "            not_eos = tf.cond(\n",
    "                # We only check for early stopping if there is at least 1 element (\n",
    "                # otherwise not_eos will crash).\n",
    "                tf.not_equal(length, 0),\n",
    "                fn_not_eos,\n",
    "                lambda: True,\n",
    "            )\n",
    "\n",
    "            return tf.cond(\n",
    "                tf.equal(batch_size, 1),\n",
    "                # If batch_size == 1, we check EOS for early stopping.\n",
    "                lambda: tf.logical_and(not_overflow, not_eos),\n",
    "                # Else, just wait for max length\n",
    "                lambda: not_overflow,\n",
    "            )\n",
    "        return not_overflow\n",
    "\n",
    "    result, logits, loss = tf.while_loop(\n",
    "        while_exit_cond,\n",
    "        infer_step,\n",
    "        [result, logits, loss],\n",
    "        shape_invariants = [\n",
    "            tf.TensorShape([None, None, None, None]),\n",
    "            tf.TensorShape(logits_shape_inv),\n",
    "            tf.TensorShape([]),\n",
    "        ],\n",
    "        back_prop = False,\n",
    "        parallel_iterations = 1,\n",
    "    )\n",
    "    if inputs_old is not None:  # Restore to not confuse Estimator.\n",
    "        features['inputs'] = inputs_old\n",
    "    # Reassign targets back to the previous value.\n",
    "    if targets_old is not None:\n",
    "        features['targets'] = targets_old\n",
    "    losses = {'training': loss}\n",
    "    if 'partial_targets' in features:\n",
    "        partial_target_length = common_layers.shape_list(\n",
    "            features['partial_targets']\n",
    "        )[1]\n",
    "        result = tf.slice(\n",
    "            result, [0, partial_target_length, 0, 0], [-1, -1, -1, -1]\n",
    "        )\n",
    "    return {\n",
    "        'outputs': result,\n",
    "        'scores': None,\n",
    "        'logits': logits,\n",
    "        'losses': losses,\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting T2TModel mode to 'infer'\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting T2TModel mode to 'infer'\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.dropout to 0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.dropout to 0.0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.label_smoothing to 0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.label_smoothing to 0.0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.layer_prepostprocess_dropout to 0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.layer_prepostprocess_dropout to 0.0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.symbol_dropout to 0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.symbol_dropout to 0.0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.attention_dropout to 0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.attention_dropout to 0.0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.relu_dropout to 0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.relu_dropout to 0.0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using variable initializer: uniform_unit_scaling\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using variable initializer: uniform_unit_scaling\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'inputs' with symbol_modality_25385_512.bottom\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'inputs' with symbol_modality_25385_512.bottom\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'targets' with symbol_modality_25385_512.targets_bottom\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'targets' with symbol_modality_25385_512.targets_bottom\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Building model body\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Building model body\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming body output with symbol_modality_25385_512.top\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming body output with symbol_modality_25385_512.top\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using variable initializer: uniform_unit_scaling\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using variable initializer: uniform_unit_scaling\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'inputs' with symbol_modality_25385_512.bottom\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'inputs' with symbol_modality_25385_512.bottom\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'targets' with symbol_modality_25385_512.targets_bottom\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'targets' with symbol_modality_25385_512.targets_bottom\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Building model body\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Building model body\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming body output with symbol_modality_25385_512.top\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming body output with symbol_modality_25385_512.top\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using variable initializer: uniform_unit_scaling\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using variable initializer: uniform_unit_scaling\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'inputs' with symbol_modality_25385_512.bottom\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'inputs' with symbol_modality_25385_512.bottom\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'targets' with symbol_modality_25385_512.targets_bottom\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'targets' with symbol_modality_25385_512.targets_bottom\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Building model body\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Building model body\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming body output with symbol_modality_25385_512.top\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming body output with symbol_modality_25385_512.top\n"
     ]
    }
   ],
   "source": [
    "class Model:\n",
    "    def __init__(self, HPARAMS = \"transformer_base\", DATA_DIR = 't2t/data'):\n",
    "        \n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.int32, [None, None])\n",
    "        self.top_p = tf.placeholder(tf.float32, None)\n",
    "        \n",
    "        self.X_seq_len = tf.count_nonzero(self.X, 1, dtype=tf.int32)\n",
    "        maxlen_decode = 50 + tf.reduce_max(self.X_seq_len)\n",
    "        \n",
    "        x = tf.expand_dims(tf.expand_dims(self.X, -1), -1)\n",
    "        y = tf.expand_dims(tf.expand_dims(self.Y, -1), -1)\n",
    "        \n",
    "        features = {\n",
    "            \"inputs\": x,\n",
    "            \"targets\": y,\n",
    "            \"target_space_id\": tf.constant(1, dtype=tf.int32),\n",
    "        }\n",
    "        self.features = features\n",
    "        \n",
    "        Modes = tf.estimator.ModeKeys\n",
    "        hparams = trainer_lib.create_hparams(HPARAMS, data_dir=DATA_DIR, problem_name=PROBLEM)\n",
    "        translate_model = registry.model('transformer')(hparams, Modes.PREDICT)\n",
    "        self.translate_model = translate_model\n",
    "        logits, _ = translate_model(features)\n",
    "        translate_model.hparams.top_p = self.top_p\n",
    "        \n",
    "        with tf.variable_scope(tf.get_variable_scope(), reuse=True):\n",
    "            self.fast_result = translate_model._greedy_infer(features, maxlen_decode)[\"outputs\"]\n",
    "            self.beam_result = translate_model._beam_decode_slow(\n",
    "                features, maxlen_decode, beam_size=5, \n",
    "                top_beams=1, alpha=1.0)[\"outputs\"]\n",
    "            self.nucleus_result = nucleus_sampling(translate_model, features, maxlen_decode)[\"outputs\"]\n",
    "            self.nucleus_result = self.nucleus_result[:,:,0,0]\n",
    "        \n",
    "        self.fast_result = tf.identity(self.fast_result, name = 'greedy')\n",
    "        self.beam_result = tf.identity(self.beam_result, name = 'beam')\n",
    "        self.nucleus_result = tf.identity(self.nucleus_result, name = 'nucleus')\n",
    "        \n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from t2t/train-base/model.ckpt-75000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from t2t/train-base/model.ckpt-75000\n"
     ]
    }
   ],
   "source": [
    "var_lists = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)\n",
    "saver = tf.train.Saver(var_list = var_lists)\n",
    "saver.restore(sess, ckpt_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('Agama biasanya bermaksud kepercayaan kepada Tuhan, atau kekuatan ghaib dan ghaib seperti Tuhan, serta amalan dan institusi yang berkaitan dengan kepercayaan itu. Agama dan kepercayaan adalah dua perkara yang sangat relevan. Tetapi Agama mempunyai makna yang lebih luas, yang merujuk kepada sistem kepercayaan yang kohesif, dan kepercayaan ini adalah mengenai aspek ilahi<EOS>',\n",
       " 'Agama biasanya bermaksud kepercayaan kepada Tuhan, atau kekuatan ghaib dan ghaib seperti Tuhan, serta amalan dan institusi yang berkaitan dengan kepercayaan itu. Agama dan kepercayaan adalah dua perkara yang sangat relevan. Tetapi Agama mempunyai makna yang lebih luas, yang merujuk kepada sistem kepercayaan kohesif, dan kepercayaan ini adalah mengenai aspek ilahi<EOS><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>',\n",
       " 'Agama biasanya bermaksud kepercayaan kepada Tuhan, atau kekuatan ghaib dan NarcPeter seperti iblis, serta amalan dan institusi yang berkaitan dengan kepercayaan itu. Agama dan kepercayaan adalah dua perkara yang sangat relevan. Tetapi agama mempunyai makna yang lebih luas, yang merujuk kepada sistem kepercayaan yang kohesif, dan kepercayaan ini adalah mengenai aspek ilahi<EOS>')"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "string = 'Religion usually means trust in God, or a supernatural and supernatural power like God, as well as practices and institutions associated with that belief. Religion and belief are two very relevant things. But Religion has a broader meaning, which refers to a system of cohesive belief, and this belief is about the divine aspect'\n",
    "encoded = encoder.encode(string) + [1]\n",
    "f, b, n = sess.run([model.fast_result, model.beam_result, model.nucleus_result], feed_dict = {model.X: [encoded], model.top_p: 0.7})\n",
    "encoder.decode(f[0]), encoder.decode(b[0]), encoder.decode(n[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('Articulate (Bahasa yang diucapkan dengan baik, Ekspresif) <> Hmmm ya Horacio Calcaterra adalah seorang ahli sukan yang kini bermain untuk Sporting Cristal yang terletak di Torneo Descentralizado. <> Dia kini bermain untuk Sporting Cristal di Torneo Descentralizado.<EOS>',\n",
       " 'Artikulat (dituturkan dengan baik, Ekspresif) <> Hmmm ya Horacio Calcaterra adalah seorang ahli sukan yang kini bermain untuk Sporting Cristal yang terletak di Torneo Descentralizado. <> Dia kini bermain untuk Sporting Cristal di Torneo Descentralizado.<EOS><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>',\n",
       " 'Artikulasi (dituturkan dengan baik, Ekspresif) <> Hmmm ya Horacio Calcaterra adalah ahli sukan yang kini bermain untuk Sporting Cristal yang terletak di Torneo Descentralizado. <> Dia Volcano bermain offers Tugu preservhalasacred par tugas di Torneo Descentralizado.<EOS>')"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "string = 'Articulate (Well-spoken, Expressive) <> Hmmm yes Horacio Calcaterra is a sportsman that currently is playing for Sporting Cristal which is located in Torneo Descentralizado. <> He currently plays for Sporting Cristal in the Torneo Descentralizado.'\n",
    "encoded = encoder.encode(string) + [1]\n",
    "f, b, n = sess.run([model.fast_result, model.beam_result, model.nucleus_result], feed_dict = {model.X: [encoded], model.top_p: 0.7})\n",
    "encoder.decode(f[0]), encoder.decode(b[0]), encoder.decode(n[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('TANGKAK - Tan Sri Muhyiddin Yassin berkata, beliau tidak mahu menyentuh isu politik pada masa ini, sebaliknya memberi tumpuan kepada kebajikan rakyat dan usaha untuk memulihkan ekonomi negara yang terjejas susulan pandemik Covid-19 Beliau berkata demikian ketika berucap pada Majlis Perjumpaan Bersama Pemimpin Dewan Undangan Negeri (DUN) Gambir di Dewan Serbaguna Bukit Gambir hari ini.<EOS>',\n",
       " 'TANGKAK - Tan Sri Muhyiddin Yassin berkata, beliau tidak mahu menyentuh isu politik pada masa ini, sebaliknya memberi tumpuan kepada kebajikan rakyat dan usaha untuk memulihkan ekonomi negara yang terjejas susulan pandemik Covid-19 Perdana Menteri menjelaskan perkara itu ketika berucap pada Majlis Perjumpaan Bersama Pemimpin Dewan Undangan Negeri (DUN) Gambir di Dewan Serbaguna Bukit Gambir hari ini.<EOS><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>',\n",
       " 'TANGKAK Drama AVI - Tan Sri Muhyiddin Yassin menyifatkan beliau tidak mahu menyentuh isu politik pada saat ini, sebaliknya hanya fokus dalam aspek kebajikan rakyat dan usaha untuk menstrukturkan semula ekonomi negara yang terjejas berikutan pandemik Covid-19 Dr. Perdana Menteri itu menjelaskan perkara itu ketika berucap pada Program Harum Bersama Pimpinan Negeri (DUN) Gambir di Dewan Serbaguna Bukit Gambir hari ini.<EOS>')"
      ]
     },
     "execution_count": 73,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "string = ('TANGKAK - Tan Sri Muhyiddin Yassin said he did not want to touch on '\n",
    " 'political issues at the moment, instead focusing on the welfare of the '\n",
    " \"people and efforts to revitalize the affected country's economy following \"\n",
    " 'the Covid-19 pandemic. The prime minister explained the matter when speaking '\n",
    " 'at a Leadership Meeting with Gambir State Assembly (DUN) leaders at the '\n",
    " 'Bukit Gambir Multipurpose Hall today.')\n",
    "encoded = encoder.encode(string) + [1]\n",
    "f, b, n = sess.run([model.fast_result, model.beam_result, model.nucleus_result], feed_dict = {model.X: [encoded], model.top_p: 0.3})\n",
    "encoder.decode(f[0]), encoder.decode(b[0]), encoder.decode(n[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from unidecode import unidecode"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "string = \"Partner's persona: i am in medical school. Partner's persona: i really wanted to be an actor. your persona: i've 2 kids. your persona: i love flowers.\"\n",
    "string = unidecode(string)\n",
    "encoded = encoder.encode(string) + [1]\n",
    "f, b = sess.run([model.fast_result, model.beam_result], feed_dict = {model.X: [encoded]})\n",
    "encoder.decode(f[0]), encoder.decode(b[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "string = 'Emmerdale is the debut studio album,songs were not released in the U.S. .[EENND] These songs were not released in the U.S. edition of said album and were previously unavailable on any U.S. release.'\n",
    "encoded = encoder.encode(string) + [1]\n",
    "f, b = sess.run([model.fast_result, model.beam_result], feed_dict = {model.X: [encoded]})\n",
    "encoder.decode(f[0]), encoder.decode(b[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
